#!/usr/bin/env python3
#
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import sys
import os
import re

# Regular expression for pulling out the different pieces of the nvml entry points.
# It will match something in the form of:
# funcname, tsapiFuncname, (argument list), "(argument type matching)", arg1[, arg2, ...])
# We place funcname, (argument list), and arg1[, arg2, ...] into groups for use later.
# A raw string is used so regex escapes such as \w, \s and \( are passed to the regex
# engine untouched instead of being parsed as (invalid) Python string escapes.
preg = re.compile(r'(nvml\w+),[^)]+(\([^)]+\)),\s+"[^"]+",\s+([^)]+)')

# Emitted as the MAX_NVML_ARGS #define in the generated declarations header, where it
# sizes the argTypes array of functionInfo_t
MAX_NVML_ARGS = 20
INJECTION_ARG_COUNT_STR = 'InjectionArgCount'
# Return type written in front of every generated function definition
NVML_RET = 'nvmlReturn_t'

#######################################################################################
# Globals
# Maps an injection key to the name of the function it was derived from, or to a
# ', '-separated list of names when the key was registered more than once
g_key_to_function = {}

#######################################################################################    
# Names and locations of the generated files
INJECTION_ARGUMENT_HEADER = 'InjectionArgument.h'
INJECTION_STRUCTS_NAME = 'nvml_injection_structs.h'
FUNCTION_DECLARATIONS_HEADER = 'nvml_generated_declarations.h'

# Headers
INJECTION_STRUCTS_PATH = 'include/' + INJECTION_STRUCTS_NAME
INJECTION_ARGUMENT_PATH = 'include/' + INJECTION_ARGUMENT_HEADER
FUNCTION_DECLARATIONS_PATH = 'include/' + FUNCTION_DECLARATIONS_HEADER
KEY_LIST_HEADER_PATH = 'include/InjectionKeys.h'

# Sources
STUB_PATH = 'src/nvml_generated_stubs.cpp'
INJECTION_CPP_PATH = 'src/InjectionArgument.cpp'
FUNCTION_INFO_PATH = 'src/FunctionInfo.cpp'
PASS_THRU_GENERATED_SRC_PATH = 'src/nvml_pass_through_generated.cpp'
KEY_LIST_PATH = 'src/InjectionKeys.cpp'
LINUX_DEFS_PATH = 'src/nvml-injection.linux_defs'

#######################################################################################    
# C-style comment banner written into the generated files, followed by two blank lines
AUTO_GENERATED_NOTICE = (
    '/*\n'
    ' * NOTE: This code is auto-generated by generate_nvml_stubs.py\n'
    ' * DO NOT EDIT MANUALLY\n'
    ' */\n'
    '\n'
    '\n'
)


#######################################################################################    
# Entry points that AllFunctions.AddFunction ignores instead of registering
skip_functions = [ 'nvmlGetBlacklistDeviceCount', 'nvmlGetBlacklistDeviceInfoByIndex' ]

#######################################################################################    
# NVML type names that get_true_arg_type() collapses to 'unsigned int'
# (or 'unsigned int *' for the pointer form). Each name is listed exactly once.
uint_aliases = [
    'nvmlBusType_t',
    'nvmlVgpuTypeId_t',
    'nvmlVgpuInstance_t',
    'nvmlDeviceArchitecture_t',
    'nvmlPowerSource_t',
    'nvmlAffinityScope_t',
]

def get_version(funcname):
    """Return the single-digit version encoded as a trailing '_vN' in funcname.

    Returns 0 when funcname is empty/None, has no '_vN' suffix, or the
    character after '_v' is not an ASCII digit (instead of raising ValueError).
    """
    if funcname and funcname[-3:-1] == '_v' and funcname[-1] in '0123456789':
        return int(funcname[-1])

    return 0

#######################################################################################    
class AllFunctionTypes(object):
    """Collects one function-pointer typedef per distinct argument-type signature.

    Functions whose argument types render to the same string share the typedef
    created for the first function seen with that signature.
    """

    def __init__(self):
        # typedef lines, in the order their signatures were first seen
        self.all_func_declarations = []
        # signature strings already covered by a typedef
        self.all_argument_type_strs = []
        # function name -> typedef name used for it
        self.funcname_to_func_type = {}
        # signature string -> typedef name
        self.arg_types_to_func_type = {}

    def AddFunctionType(self, funcname, funcinfo):
        """Register funcname, creating a new typedef only for an unseen signature."""
        signature = funcinfo.GetArgumentTypesAsString()

        if signature in self.all_argument_type_strs:
            # Known signature: reuse the typedef that already exists for it
            self.funcname_to_func_type[funcname] = self.arg_types_to_func_type[signature]
            return

        typedef_line = "typedef nvmlReturn_t (*%s_f)%s;" % (funcname, funcinfo.GetArgumentList())
        typedef_name = "%s_f" % funcname
        self.all_func_declarations.append(typedef_line)
        self.arg_types_to_func_type[signature] = typedef_name
        self.funcname_to_func_type[funcname] = typedef_name
        self.all_argument_type_strs.append(signature)

    def GetAllFunctionDeclarations(self):
        """Return the list of typedef lines collected so far."""
        return self.all_func_declarations

    def GetFunctionType(self, funcname):
        """Return the typedef name for funcname; raises KeyError if unregistered."""
        return self.funcname_to_func_type[funcname]

#######################################################################################    
class AllFunctions(object):
    """Registry of parsed functions keyed by name, with tracking of '_vN' variants."""

    ###################################################################################
    def __init__(self):
        # function name -> function info object
        self.func_dict = {}
        # versioned function name -> version number (only versions > 0)
        self.versioned_funcs = {}

    ###################################################################################
    def AddFunction(self, funcinfo):
        """Register funcinfo unless its name is listed in skip_functions."""
        funcname = funcinfo.GetName()
        if funcname in skip_functions:
            return

        self.func_dict[funcname] = funcinfo

        version = get_version(funcname)
        if version > 0:
            self.versioned_funcs[funcname] = version

    ###################################################################################
    def GetFunctionDict(self):
        """Return the name -> function info dictionary."""
        return self.func_dict

    ###################################################################################
    def RemoveEarlierVersions(self):
        """Drop every lower-numbered variant of each versioned function.

        For 'name_vN' this removes 'name' (treated as version 1) and
        'name_v2' .. 'name_v(N-1)' when present, then clears the version table.
        """
        for funcname, version in self.versioned_funcs.items():
            base_name = funcname[:-3]
            for older in range(1, version):
                if older == 1:
                    stale_name = base_name
                else:
                    stale_name = "%s_v%d" % (base_name, older)
                # Missing entries are simply ignored
                self.func_dict.pop(stale_name, None)

        self.versioned_funcs = {}


#######################################################################################    
class FunctionInfo(object):
    """Holds the name, argument list, argument names and argument types of one function."""

    ###################################################################################
    def __init__(self, funcname, arg_list, arg_names):
        """Store the stripped name, the normalized argument list and the parsed types.

        arg_names is kept exactly as passed in. The argument types are parsed
        from the original (un-normalized) arg_list.
        """
        self.funcname  = funcname.strip()
        self.arg_list  = self.CleanArgList(arg_list)
        self.arg_names = arg_names
        self.arg_types = get_argument_types_from_argument_list(arg_list)

    ###################################################################################
    def CleanArgList(self, arg_list):
        """Rewrite pointer declarators as 'T *name' rather than 'T* name'.

        The text is split on '*'; each piece gets a trailing space if it does
        not already end in one, and a single leading space after a '*' is
        dropped. The result therefore ends with a space. Empty pieces (from
        '**' or an empty arg_list) are handled without raising IndexError.
        """
        tokens = arg_list.split('*')

        new_list = tokens[0]
        if new_list and new_list[-1] != ' ':
            new_list += ' '

        for token in tokens[1:]:
            # Glue the name directly to the '*': drop one space that followed it
            new_list += '*' + (token[1:] if token[:1] == ' ' else token)
            # An empty token means two adjacent '*'; keep them adjacent
            if token and token[-1] != ' ':
                new_list += ' '

        return new_list

    ###################################################################################
    def GetName(self):
        """Return the function name."""
        return self.funcname

    ###################################################################################
    def GetArgumentList(self):
        """Return the normalized argument list text."""
        return self.arg_list

    ###################################################################################
    def GetArgumentNames(self):
        """Return the argument names as supplied to the constructor."""
        return self.arg_names

    ###################################################################################
    def GetArgumentTypes(self):
        """Return the list of parsed argument types."""
        return self.arg_types

    ###################################################################################
    def GetArgumentTypesAsString(self):
        """Return the argument types joined with ',' ('' when there are none)."""
        return ','.join(str(arg_type) for arg_type in self.arg_types)
                

#######################################################################################    
def get_true_arg_type(arg_type):
    """Map uint alias types (and pointers to them) to the plain unsigned int spelling.

    Any other type is returned unchanged.
    """
    if not is_pointer_type(arg_type):
        return 'unsigned int' if arg_type in uint_aliases else arg_type

    # Pointer form: look up the type without its trailing ' *'
    if arg_type[:-2] in uint_aliases:
        return 'unsigned int *'

    return arg_type

#######################################################################################    
def remove_extra_spaces(text):
    """Collapse every run of two or more spaces in text into a single space."""
    return re.sub(' {2,}', ' ', text)

#######################################################################################    
def get_function_signature(entry_point, first):
    """Extract (funcname, argument list, argument names) from one entry point.

    Line breaks are replaced by spaces, the text is stripped, and its first and
    last characters are dropped, leaving something of the form:
    funcname, tsapiFuncname, (argument list), "(argument type matching)", arg1, arg2, ...)
    Returns (None, None, None) when the text does not match. A failed match is
    reported on stdout unless the text is the nvml.h include remnant or `first`
    is set (the first token is everything before the first entry point).
    """
    flattened = entry_point.replace('\n', ' ').strip()[1:-1]
    match = preg.search(flattened)
    if match is not None:
        return tuple(remove_extra_spaces(match.group(index)) for index in (1, 2, 3))

    is_include_remnant = flattened == "include \"nvml.h"
    if not is_include_remnant and not first:
        print("no match found in entry point = '%s'" % flattened)
    return None, None, None

#######################################################################################    
def print_body_line(line, file, extra_indent):
    """Write line to file, indented by one 4-space level plus extra_indent more levels.

    A newline is appended. A non-positive extra_indent adds no extra levels.
    """
    levels = 1 + len(range(0, extra_indent))
    file.write("%s%s\n" % ("    " * levels, line))

#######################################################################################    
def add_function_type(function_type_dict, funcname, arg_list, function_types):
    """Append a function-pointer typedef for funcname to function_types.

    function_type_dict is accepted but not used.
    """
    function_types.append("typedef nvmlReturn_t (*%s_f)%s;" % (funcname, arg_list))
    
# Prefixes stripped from a function name to obtain its injection key.
# get_function_info_from_name() tries them in order and stops at the first one that
# leaves a non-empty suffix, so a longer prefix must be listed before any shorter
# prefix it starts with (e.g. 'nvmlDeviceGetHandleBy' before 'nvmlDeviceGet').
keyPrefixes = [
    'nvmlDeviceGetHandleBy',
    'nvmlDeviceGet',
    'nvmlSystemGet',
    'nvmlDeviceSet',
    'nvmlUnitGet',
    'nvmlUnitSet',
    'nvmlVgpuTypeGet',
    'nvmlVgpuInstanceGet',
    'nvmlVgpuInstanceSet',
    'nvmlGet',
    'nvmlSet',
    'nvmlGpuInstanceGet',
    'nvmlComputeInstanceGet',
    'nvmlDeviceClear',
    'nvmlDeviceFreeze',
    'nvmlDeviceModify',
    'nvmlDeviceQuery',
    'nvmlDeviceCreate',
    'nvmlDeviceReset',
    'nvmlDeviceIs',
    'nvmlDevice',
]

# Fallback prefix, tried only when none of keyPrefixes produced a key
gpmPrefix = 'nvmlGpm'

#######################################################################################    
def get_suffix_if_match(funcname, prefix):
    """Return the part of funcname after prefix, or None if it does not start with prefix.

    The returned suffix is an empty string when funcname equals prefix.
    """
    if not funcname.startswith(prefix):
        return None

    return funcname[len(prefix):]

#######################################################################################    
def get_function_info_from_name(funcname):
    """Derive the injection key and version for funcname and record the key.

    The key is funcname with the first matching entry of keyPrefixes removed.
    If none matches, gpmPrefix is tried and a trailing 'Get' is dropped from
    the result. A trailing '_vN' on the key is removed and N becomes the
    version (default 1). Every call with a non-empty key updates
    g_key_to_function: a key already present has funcname appended to its
    ', '-separated value.

    Returns (key, version). key is None or '' when no key could be derived.
    """
    version = 1
    key = None

    for candidate in keyPrefixes:
        key = get_suffix_if_match(funcname, candidate)
        if key:
            break

    if not key:
        key = get_suffix_if_match(funcname, gpmPrefix)
        if not key:
            print("Can't get key for %s" % funcname)
        elif key[-3:] == 'Get':
            key = key[:-3]

    if not key:
        return key, version

    # Check for version at the end
    if key[-3:-1] == '_v':
        version = int(key[-1])
        key = key[:-3]

    previous = g_key_to_function.get(key) if key in g_key_to_function else None
    if key in g_key_to_function:
        g_key_to_function[key] = "%s, %s" % (previous, funcname)
    else:
        g_key_to_function[key] = funcname

    return key, version

#######################################################################################    
def check_and_write_get_string_body(stub_file, key, arg_types, arg_names):
    """Emit a string-getter body when the signature is (non-pointer, char *, size).

    The size argument must be an unsigned int or a pointer to one. Returns
    True when the body was written and False (writing nothing) otherwise.
    """
    if len(arg_types) != 3 or arg_types[1] != CHAR:
        return False

    if is_pointer(arg_types[0]):
        return False

    size_type = arg_types[2]
    if size_type not in (UINT, UINT_PTR):
        return False

    # InjectionNvml::GetString will return a std::string associated with two keys
    print_body_line("InjectionArgument arg(%s);" % arg_names[0], stub_file, 1)
    print_body_line("std::string buf = InjectedNvml->GetString(arg, \"%s\");" % (key), stub_file, 1)

    # The buffer length is passed either by value or through a pointer
    if size_type == UINT:
        copy_line = "snprintf(%s, %s, \"%s\", buf.c_str());" % (arg_names[1], arg_names[2], '%s')
    else:
        copy_line = "snprintf(%s, *%s, \"%s\", buf.c_str());" % (arg_names[1], arg_names[2], '%s')
    print_body_line(copy_line, stub_file, 1)

    return True

#######################################################################################    
def is_pointer(arg_type):
    """Return True when the last character of arg_type is '*'."""
    return arg_type[-1] == '*'

# Argument type spellings that the generators compare parsed argument types against.
# Pointer types are written as 'T *'. Each name is defined exactly once.
CONST_CHAR = 'const char *'
CHAR = 'char *'
NVML_DEVICE = 'nvmlDevice_t'
NVML_DEVICE_PTR = 'nvmlDevice_t *'
CLOCKTYPE = 'nvmlClockType_t'
UINT = 'unsigned int'
UINT_PTR = 'unsigned int *'
VGPUTYPEID = 'nvmlVgpuTypeId_t'
UNIT = 'nvmlUnit_t'
VGPU_INSTANCE = 'nvmlVgpuInstance_t'

#######################################################################################    
def print_ungenerated_function(funcname, arg_types):
    """Build a comma-separated summary of arg_types for a function with no generated body.

    The diagnostic print that would use the summary is currently disabled, so
    this function produces no output and returns None.
    """
    arg_type_string = ''
    for arg_type in arg_types:
        arg_type_string = (arg_type_string + ',%s' % arg_type) if len(arg_type_string) else arg_type
    #print("Not generated: %s with (%d) arg_types: %s" % (funcname, len(arg_types), arg_type_string))

#######################################################################################    
def generate_getter_functions(stub_file, funcname, arg_list, arg_types, arg_names, justifyLen):
    """Write the injection body of a getter stub, chosen by the shape of its arguments.

    The branches below are tried in order and the first matching one wins, so
    their order is significant. When a body is written, a final
    'return NVML_SUCCESS;' line is appended.

    stub_file  -- file object the C++ lines are written to
    funcname   -- name of the NVML function being generated
    arg_list   -- not used by this function
    arg_types  -- list of argument type strings (pointers spelled 'T *')
    arg_names  -- list of argument names, parallel to arg_types
    justifyLen -- column width used to left-justify some left-hand sides

    Returns True when a specialized body was generated, False otherwise.
    Note that get_function_info_from_name() also records the key in
    g_key_to_function each time it is called.
    """
    generated = True
    key, version = get_function_info_from_name(funcname)

    if funcname == "nvmlDeviceGetFieldValues":
        # Special case: reject a null third argument, then delegate to GetFieldValues
        print_body_line("if (%s == nullptr)" % arg_names[2], stub_file, 1)
        print_body_line("{", stub_file, 1)
        print_body_line("return NVML_ERROR_INVALID_ARGUMENT;", stub_file, 2)
        print_body_line("}\n", stub_file, 1)
        print_body_line("InjectedNvml->GetFieldValues(%s, %s, %s);" % (arg_names[0], arg_names[1], arg_names[2]), stub_file, 1)
    elif len(arg_types) == 2 and arg_types[1] == NVML_DEVICE_PTR:
        # InjectedNvml::GetNvmlDevice returns an nvmlDevice_t and accepts a string identifier, a string describing
        # the identifier
        print_body_line("InjectionArgument identifier(%s);" % arg_names[0], stub_file, 1)
        print_body_line("*%s = InjectedNvml->GetNvmlDevice(identifier, \"%s\");" % (arg_names[1], key), stub_file, 1)
#        print("GetNvmlDevice: %s - %s" % (funcname, key))
    elif len(arg_types) >= 3 and arg_types[1] == CLOCKTYPE:
        # Clock getters: only the 3- and 4-argument forms are generated
        if len(arg_types) >= 5: # need to write code to handle the 5 argument version
            generated = False

        if len(arg_types) == 4:
            # InjectedNvml::GetClock returns an unsigned int and receives an nvmlDevice_t, a clock type, and a clock ID
            print_body_line("*%s = InjectedNvml->GetClock(%s, %s, %s);" % (arg_names[3].ljust(justifyLen-1), arg_names[0], arg_names[1], arg_names[2]), stub_file, 1)
        elif len(arg_types) == 3:
            print_body_line("*%s = InjectedNvml->GetClockInfo(%s, \"%s\", %s);" % (arg_names[2].ljust(justifyLen-1), arg_names[0], key, arg_names[1]), stub_file, 1)
    elif len(arg_types) == 2 and arg_types[0] == NVML_DEVICE:
        print_body_line("InjectionArgument arg(%s);" % (arg_names[1]), stub_file, 1)
        # SimpleDeviceGet() is a function that accepts an nvmlDevice_t and a function name, and
        # returns an InjectionArgument populated with the associated value
        print_body_line("arg.SetValueFrom(InjectedNvml->SimpleDeviceGet(%s, \"%s\"));" % (arg_names[0], key), stub_file, 1)
    elif check_and_write_get_string_body(stub_file, key, arg_types, arg_names):
        # The helper already wrote the body when it returned True
        pass
        #print("GetString: %s - %s" % (funcname, key))
    elif len(arg_types) == 3 and arg_types[0] == NVML_DEVICE and is_pointer(arg_types[2]):
        if is_pointer(arg_types[1]):
            # Two pointer arguments after the device: treat them as one compound value
            print_body_line("std::vector<InjectionArgument> values;", stub_file, 1)
            print_body_line("values.push_back(InjectionArgument(%s));" % arg_names[1], stub_file, 1)
            print_body_line("values.push_back(InjectionArgument(%s));" % arg_names[2], stub_file, 1)
            print_body_line("CompoundValue cv(values);", stub_file, 1)
            # GetCompoundValue will set the variables in the vector, reflecting the order they're supplied in
            print_body_line("InjectedNvml->GetCompoundValue(%s, \"%s\", cv);" % (arg_names[0], key), stub_file, 1)
            #print("Get compound value is covering: %s" % funcname)
        else:
            # Non-pointer second argument is passed as an extra key; the third receives the result
            print_body_line("InjectionArgument output(%s);" % arg_names[2], stub_file, 1)
            print_body_line("InjectionArgument arg(%s);" % arg_names[1], stub_file, 1)
            print_body_line("output.SetValueFrom(InjectedNvml->DeviceGetWithExtraKey(%s, \"%s\", arg));" % (arg_names[0], key), stub_file, 1)
#            print("DeviceGetWithExtraKey for %s" % funcname)
    elif len(arg_types) == 1 and is_pointer(arg_types[0]):
        # Single pointer argument: value looked up by key only
        print_body_line("InjectionArgument arg(%s);" % arg_names[0], stub_file, 1)
        print_body_line("arg.SetValueFrom(InjectedNvml->ObjectlessGet(\"%s\"));" % (key), stub_file, 1)
#        print("ObjectlessGet for %s" % funcname)
    elif len(arg_types) == 2 and arg_types[0] == CHAR and arg_types[1] == UINT:
        # (char *, unsigned int): copy the string looked up by key into the buffer
        lhand = "std::string str"
        print_body_line("%s = InjectedNvml->ObjectlessGet(\"%s\").AsString();" % (lhand.ljust(justifyLen), key), stub_file, 1)
        print_body_line("snprintf(%s, %s, \"%s\", str.c_str());" % (arg_names[0], arg_names[1], "%s"), stub_file, 1)
        #print("GetString: %s - %s" % (funcname, key))
    elif len(arg_types) == 2 and arg_types[0] == UNIT:
        print_body_line("InjectionArgument output(%s);" % arg_names[1], stub_file, 1)
        print_body_line("output.SetValueFrom(InjectedNvml->UnitGet(%s, \"%s\"));" % (arg_names[0], key), stub_file, 1)
    elif len(arg_types) == 2 and arg_types[0] == VGPU_INSTANCE:
        print_body_line("InjectionArgument output(%s);" % arg_names[1], stub_file, 1)
        print_body_line("output.SetValueFrom(InjectedNvml->VgpuInstanceGet(%s, \"%s\"));" % (arg_names[0], key), stub_file, 1)
    elif len(arg_types) == 2 and arg_types[0] == VGPUTYPEID:
        print_body_line("InjectionArgument output(%s);" % arg_names[1], stub_file, 1)
        print_body_line("output.SetValueFrom(InjectedNvml->GetByVgpuTypeId(%s, \"%s\"));" % (arg_names[0], key), stub_file, 1)
    else:
        # No known shape matched
        print_ungenerated_function(funcname, arg_types)
        generated = False


    if generated:
        print_body_line('return NVML_SUCCESS;', stub_file, 1)

    return generated

#######################################################################################    
def is_getter(funcname):
    """Return True when 'Get' occurs anywhere in funcname."""
    return "Get" in funcname

#######################################################################################    
def is_setter(funcname):
    """Return True when 'Set' occurs anywhere in funcname."""
    return "Set" in funcname

#######################################################################################    
def generate_setter_functions(stub_file, funcname, arg_list, arg_types, arg_names):
    """Write the injection body of a setter stub for device functions with 2 or 3 arguments.

    Two arguments produce a SimpleDeviceSet call. Three arguments produce a
    DeviceSetCompoundValue call, except for the two functions named below,
    which use DeviceSetWithExtraKey. A 'return NVML_SUCCESS;' line follows the
    body. Any other shape writes nothing.

    Returns True when a body was generated, False otherwise.
    """
    key, version = get_function_info_from_name(funcname)
    arg_count = len(arg_types)

    if not (arg_count in (2, 3) and arg_types[0] == NVML_DEVICE):
        print_ungenerated_function(funcname, arg_types)
        return False

    if arg_count == 2:
        print_body_line("InjectionArgument arg(%s);" % (arg_names[1].strip()), stub_file, 1)
        print_body_line("InjectedNvml->SimpleDeviceSet(%s, \"%s\", arg);" % (arg_names[0], key), stub_file, 1)
    elif funcname in ('nvmlDeviceSetFanSpeed_v2', 'nvmlDeviceSetTemperatureThreshold'):
        # Second argument selects what is being set; third is the value
        print_body_line("InjectionArgument extraKey(%s);" % arg_names[1], stub_file, 1)
        print_body_line("InjectionArgument value(%s);" % arg_names[2], stub_file, 1)
        print_body_line("InjectedNvml->DeviceSetWithExtraKey(%s, \"%s\", extraKey, value);" % (arg_names[0], key), stub_file, 1)
    else:
        print_body_line("std::vector<InjectionArgument> values;", stub_file, 1)
        for name in (arg_names[1], arg_names[2]):
            print_body_line("values.push_back(InjectionArgument(%s));" % name.strip(), stub_file, 1)
        print_body_line("CompoundValue cv(values);", stub_file, 1)
        print_body_line("InjectedNvml->DeviceSetCompoundValue(%s, \"%s\", cv);" % (arg_names[0], key), stub_file, 1)

    print_body_line("return NVML_SUCCESS;", stub_file, 1)
    return True

# Functions for which generate_injection_function() never emits a specialized body;
# it returns False for them immediately
cant_generate = [ 
    'nvmlDeviceGetVgpuMetadata',
    'nvmlDeviceSetDriverModel', # Windows only
    'nvmlDeviceSetMigMode', # Too unique - requires specific checks
] 
#######################################################################################    
def generate_injection_function(stub_file, funcname, arg_list, arg_types, arg_names, justifyLen):
    """Dispatch to the getter or setter body generator based on funcname.

    Names listed in cant_generate are refused up front. A name containing
    'Get' is treated as a getter even if it also contains 'Set'.

    Returns True when a specialized body was written, False otherwise.
    """
    if funcname in cant_generate:
        return False

    if is_getter(funcname):
        return generate_getter_functions(stub_file, funcname, arg_list, arg_types, arg_names, justifyLen)

    if is_setter(funcname):
        return generate_setter_functions(stub_file, funcname, arg_list, arg_types, arg_names)

    print_ungenerated_function(funcname, arg_types)
    return False

#######################################################################################    
def write_function_definition_start(fileHandle, funcname, arg_list):
    """Write the function signature followed by the opening brace.

    If the space-collapsed signature fits in 120 characters it is written on
    one line. Otherwise arg_list is split on ',' and written one piece per
    line, with continuation lines padded to the width of the return type and
    function name.
    """
    first_part = "%s %s" % (NVML_RET, funcname)
    one_line = remove_extra_spaces("%s%s" % (first_part, arg_list)).strip()

    if len(one_line) <= 120:
        fileHandle.write("%s\n{\n" % one_line)
        return

    padding = " ".ljust(len(first_part))
    pieces = arg_list.split(',')

    fileHandle.write("%s%s,\n" % (first_part, pieces[0]))
    for middle_piece in pieces[1:-1]:
        fileHandle.write("%s%s,\n" % (padding, middle_piece))
    fileHandle.write("%s %s\n{\n" % (padding, pieces[-1].strip()))

#######################################################################################    
def write_function(stub_file, funcinfo, all_functypes):
    """Write the complete C++ stub for one NVML function to stub_file.

    The stub has a pass-through branch (taken when GLOBAL_PASS_THROUGH_MODE is
    set) and an injection branch. The injection branch uses a specialized body
    when generate_injection_function() can produce one. Otherwise it falls
    back to the generic Get/Set wrappers, packing the arguments into a vector.
    For functions whose first argument is an nvmlDevice_t, the device handle
    is passed to the wrapper separately and only the remaining arguments are
    packed.

    stub_file     -- file object the stub is written to
    funcinfo      -- object providing GetName/GetArgumentList/GetArgumentNames/
                     GetArgumentTypes
    all_functypes -- not used by this function

    Returns True when a specialized body was generated, False when the generic
    wrapper fallback was written.
    """
    funcname = funcinfo.GetName()
    arg_list = funcinfo.GetArgumentList()
    arg_names = funcinfo.GetArgumentNames()
    arg_types = funcinfo.GetArgumentTypes()
    key, _ = get_function_info_from_name(funcname)

    write_function_definition_start(stub_file, funcname, arg_list)

    # Pass-through branch: loads the real function on first use, but currently
    # always returns NVML_ERROR_NOT_SUPPORTED rather than forwarding the call
    print_body_line("if (GLOBAL_PASS_THROUGH_MODE)", stub_file, 0)
    print_body_line("{", stub_file, 0)
    print_body_line("auto PassThruNvml = PassThruNvml::GetInstance();", stub_file, 1)
    print_body_line("if (PassThruNvml->IsLoaded(__func__) == false)", stub_file, 1)
    print_body_line("{", stub_file, 1)
    print_body_line("PassThruNvml->LoadFunction(__func__);", stub_file, 2)
    print_body_line("}", stub_file, 1)
    print_body_line("return NVML_ERROR_NOT_SUPPORTED;", stub_file, 1)
    print_body_line("}", stub_file, 0)
    print_body_line("else", stub_file, 0)
    print_body_line("{", stub_file, 0)

    arguments = [arg.strip() for arg in arg_names.split(",")]

    start = "auto InjectedNvml"
    print_body_line("%s = InjectedNvml::GetInstance();" % start, stub_file, 1)
    generated = bool(generate_injection_function(stub_file, funcname, arg_list, arg_types, arguments, len(start)))
    if not generated:
        useDevice = arg_types[0] == NVML_DEVICE
        print_body_line("std::vector<InjectionArgument> args;", stub_file, 1)
        # Skip only the leading device handle; every other argument is packed
        packed_arguments = arguments[1:] if useDevice else arguments
        for argument in packed_arguments:
            print_body_line("args.push_back(InjectionArgument(%s));" % argument, stub_file, 1)
        stub_file.write("\n")
        print_body_line("if (InjectedNvml->IsGetter(__func__))", stub_file, 1)
        print_body_line("{", stub_file, 1)
        if useDevice:
            print_body_line("return InjectedNvml->DeviceGetWrapper(__func__, \"%s\", %s, args);" % (key, arguments[0]), stub_file, 2)
        else:
            print_body_line("return InjectedNvml->GetWrapper(__func__, args);", stub_file, 2)
        print_body_line("}", stub_file, 1)
        print_body_line("else", stub_file, 1)
        print_body_line("{", stub_file, 1)
        if useDevice:
            print_body_line("return InjectedNvml->DeviceSetWrapper(__func__, \"%s\", %s, args);" % (key, arguments[0]), stub_file, 2)
        else:
            print_body_line("return InjectedNvml->SetWrapper(__func__, args);", stub_file, 2)
        print_body_line("}", stub_file, 1)
    print_body_line("}", stub_file, 0)
    print_body_line("return NVML_SUCCESS;", stub_file, 0)

    # Write the end of the function
    stub_file.write("}\n\n")
    return generated

#######################################################################################    
def write_declarations_file(function_declarations, output_dir):
    """Write the generated declarations header under output_dir.

    The header contains the auto-generated notice, the MAX_NVML_ARGS define,
    the functionInfo_t struct, and one line per entry of function_declarations.
    """
    header_path = '%s/%s' % (output_dir, FUNCTION_DECLARATIONS_PATH)
    struct_fields = (
        'const char *funcname;',
        'unsigned int argCount;',
        'injectionArgType_t argTypes[MAX_NVML_ARGS];',
    )

    with open(header_path, 'w') as header:
        header.write("#pragma once\n\n")

        header.write(AUTO_GENERATED_NOTICE)
        header.write('#include <nvml.h>\n\n')

        header.write("#define MAX_NVML_ARGS %d\n" % MAX_NVML_ARGS)
        header.write("typedef struct\n")
        header.write("{\n")
        for field in struct_fields:
            print_body_line(field, header, 0)
        header.write('} functionInfo_t;\n\n')

        header.write('// clang-format off\n')
        header.writelines("%s\n" % declaration for declaration in function_declarations)

#######################################################################################    
def write_stub_file_header(stub_file):
    """Write the notice, includes, extern "C" opener and pass-through flag to stub_file."""
    preamble = (
        AUTO_GENERATED_NOTICE,
        "#include \"InjectedNvml.h\"\n",
        "#include \"nvml.h\"\n",
        "#include \"%s\"\n\n" % FUNCTION_DECLARATIONS_HEADER,
        "#include \"PassThruNvml.h\"\n\n",
        "#ifdef __cplusplus\n",
        "extern \"C\"\n{\n#endif\n\n",
        "bool GLOBAL_PASS_THROUGH_MODE = false;\n\n",
    )
    for chunk in preamble:
        stub_file.write(chunk)

#######################################################################################    
def get_argument_types_from_argument_list(arg_list):
    """Return the type of each argument in a C argument list, pointers spelled 'T *'.

    arg_list may be wrapped in parentheses. Each comma-separated argument is
    split on single spaces: the last word is the name and the words before it
    form the type. A '*' attached to the name or to the type is normalized to
    a trailing ' *' on the type.
    """
    trimmed = arg_list.strip()
    if trimmed[0] == '(':
        trimmed = trimmed[1:]
    if trimmed[-1] == ')':
        trimmed = trimmed[:-1]

    argument_types = []
    for declaration in trimmed.split(','):
        words = declaration.strip().split(' ')

        if len(words) == 2:
            arg_type = words[0]
            arg_name = words[1].strip()
        else:
            # One word: it serves as both type and name. Three or more: all
            # but the last word make up the type.
            arg_type = ' '.join(words[:1] + words[1:-1])
            arg_name = words[-1]

        if arg_name[0] == '*':
            # 'T *name'
            arg_type += ' *'
        elif arg_type[-1] == '*' and arg_type[-2] != ' ':
            # 'T* name'
            arg_type = arg_type[:-1] + ' *'

        argument_types.append(arg_type)

    return argument_types

#######################################################################################    
def build_argument_type_list(arg_list, all_argument_types):
    """Parse arg_list and add any new canonical types to all_argument_types.

    Each parsed type is canonicalized with get_true_arg_type() before the
    membership check. all_argument_types is modified in place.

    Returns the parsed (non-canonicalized) argument types of arg_list.
    """
    parsed_types = get_argument_types_from_argument_list(arg_list)

    for canonical_type in map(get_true_arg_type, parsed_types):
        if canonical_type not in all_argument_types:
            all_argument_types.append(canonical_type)

    return parsed_types

#######################################################################################    
def is_pointer_type(arg_type):
    """Return True when arg_type ends with ' *' (space followed by asterisk)."""
    return arg_type.endswith(' *')

#######################################################################################    
def is_nvml_enum(arg_type):
    """Return True when arg_type starts with 'nvml'."""
    return arg_type.startswith('nvml')

#######################################################################################    
def ends_with_t(arg_type):
    """Return True when arg_type ends with '_t'."""
    return arg_type.endswith('_t')

#######################################################################################    
def transform_arg_type(arg_type, arg_type_dict):
    """Derive a member name and a suffix for arg_type and record them in arg_type_dict.

    arg_type_dict[arg_type] is set to [member name, is-pointer flag, suffix].
    'char *' and 'const char *' are treated as strings. A type starting with
    'nvml' has that prefix (and a trailing '_t') removed. Any other type uses
    the first letter of each word for the name. Pointer types get 'Ptr'
    appended to both name and suffix.

    Returns (member name, is-pointer flag).
    """
    string_types = {
        'char *': ('str', 'Str'),
        'const char *': ('const_str', 'ConstStr'),
    }
    if arg_type in string_types:
        name, suffix = string_types[arg_type]
        arg_type_dict[arg_type] = [name, True, suffix]
        return name, True

    original_type = arg_type
    is_ptr = is_pointer_type(arg_type)
    # Remove the ' *' to generate the name
    base_type = arg_type[:-2] if is_ptr else arg_type

    if base_type == 'nvmlBAR1Memory_t':
        name, suffix = 'bar1Memory', 'BAR1Memory'
    elif is_nvml_enum(base_type):
        # Handle nvml enum type e.g. nvmlDevice_t => device
        stem = base_type[4:-2] if ends_with_t(base_type) else base_type[4:]
        name = '%s%s' % (base_type[4].lower(), stem[1:])
        suffix = '%s' % (stem)
    else:
        name, suffix = '', ''
        # Make the variable name the first letter of each word
        for word in base_type.strip().split(' '):
            name += word[0]
            if word == 'unsigned':
                # Replaces (does not extend) whatever suffix was built so far
                suffix = 'U'
            else:
                suffix += '%s%s' % (word[0].upper(), word[1:])

    if is_ptr:
        name += 'Ptr'
        suffix += 'Ptr'

    arg_type_dict[original_type] = [name, is_ptr, suffix]

    return name, is_ptr

#######################################################################################    
def get_enum_name(arg_type):
    """Return the upper-case 'INJECTION_...' enumerator name for arg_type.

    A type starting with 'nvml' loses that prefix (and a trailing '_t'). For a
    multi-word type, 'unsigned' contributes 'U' and every other word
    contributes its upper-case form followed by '_'. Pointer types end in
    'PTR'.
    """
    prefix = 'INJECTION_'

    is_ptr = is_pointer_type(arg_type)
    base_type = arg_type[:-2] if is_ptr else arg_type
    pointer_suffix = '_PTR' if is_ptr else ''

    if is_nvml_enum(base_type):
        stem = base_type[4:-2] if ends_with_t(base_type) else base_type[4:]
        return '%s%s%s' % (prefix, stem.upper(), pointer_suffix)

    words = base_type.strip().split(' ')
    if len(words) == 1:
        return '%s%s%s' % (prefix, base_type.upper(), pointer_suffix)

    enum_name = prefix + ''.join(
        'U' if word == 'unsigned' else '%s_' % word.upper() for word in words
    )
    if is_ptr:
        return enum_name + 'PTR'

    # Non-pointer: drop the final character of the assembled name
    return enum_name[:-1]

#######################################################################################    
def print_memcpy(fileHandle, indentLevel, destName, srcName, destIsRef, srcIsRef):
    """
    Emit a C++ memcpy() from other.m_value.<srcName> into m_value.<destName>.

    fileHandle  - handle passed through to print_body_line
    indentLevel - indent level passed through to print_body_line
    destName    - destination member of m_value
    srcName     - source member of other.m_value
    destIsRef   - truthy: copy to the member's address and use sizeof(member);
                  falsy: use the member as the pointer and sizeof(*member)
    srcIsRef    - truthy: take the address of the source member

    The statement goes on one line when its estimated width is at most 120,
    on two lines when the estimate is 121-123, and on three lines otherwise.
    """
    dest_member = 'm_value.%s' % destName
    src_member = 'other.m_value.%s' % srcName

    if destIsRef:
        dest_arg = '&' + dest_member
        size_arg = 'sizeof(%s)' % dest_member
    else:
        dest_arg = dest_member
        size_arg = 'sizeof(*%s)' % dest_member

    src_arg = '&' + src_member if srcIsRef else src_member

    one_liner = 'memcpy(%s, %s, %s);' % (dest_arg, src_arg, size_arg)

    # NOTE(review): the extra 4 presumably accounts for indentation that
    # print_body_line adds on top of indentLevel - confirm against that helper
    estimated_width = len(one_liner) + (4 * indentLevel) + 4

    if estimated_width <= 120:
        print_body_line(one_liner, fileHandle, indentLevel)
        return

    if estimated_width < 124:
        print_body_line('memcpy(', fileHandle, indentLevel)
        print_body_line('%s, %s, %s);' % (dest_arg, src_arg, size_arg), fileHandle, indentLevel + 1)
        return

    print_body_line('memcpy(%s,' % dest_arg, fileHandle, indentLevel)
    for continuation in ('   %s,' % src_arg, '   %s);' % size_arg):
        print_body_line(continuation, fileHandle, indentLevel + 1)

#######################################################################################    
def print_equals_and_set(fileHandle, indentLevel, lhName, rhName, lIsPtr, rIsPtr):
    """
    Emit a C++ assignment into this->m_value.<lhName> followed by 'set = true;'.

    fileHandle  - handle passed through to print_body_line
    indentLevel - indent level passed through to print_body_line
    lhName      - member of this->m_value being assigned
    rhName      - member of other.m_value being read; lhName is used when empty
    lIsPtr      - truthy: dereference the left-hand member
    rIsPtr      - truthy: dereference the right-hand member
    """
    target = ('*' if lIsPtr else '') + 'this->m_value.%s' % lhName
    source = ('*' if rIsPtr else '') + 'other.m_value.%s' % (rhName or lhName)

    print_body_line('%s = %s;' % (target, source), fileHandle, indentLevel)

    # Pad 'set' to the width of the left-hand side so both '=' signs line up
    print_body_line('%s = true;' % 'set'.ljust(len(target)), fileHandle, indentLevel)

#######################################################################################    
def write_string_case_entry(injectionCpp):
    """
    Emit the fixed 'case INJECTION_STRING:' block of the generated switch.

    The emitted C++ sets m_str from another string argument, or from a
    non-null char * / const char * argument, and marks 'set' in each case.
    """
    # (indent level, emitted text) in output order
    emitted = (
        (1, 'case INJECTION_STRING:'),
        (1, '{'),
        (2, 'if (other.m_type == INJECTION_STRING)'),
        (2, '{'),
        (3, 'this->m_str = other.m_str;'),
        (3, 'set         = true;'),
        (2, '}'),
        (2, 'else if (other.m_type == INJECTION_CHAR_PTR && other.m_value.str != nullptr)'),
        (2, '{'),
        (3, 'this->m_str = other.m_value.str;'),
        (3, 'set         = true;'),
        (2, '}'),
        (2, 'else if (other.m_type == INJECTION_CONST_CHAR_PTR && other.m_value.const_str != nullptr)'),
        (2, '{'),
        (3, 'this->m_str = other.m_value.const_str;'),
        (3, 'set         = true;'),
        (2, '}'),
        (2, 'break;'),
        (1, '}'),
    )

    for level, text in emitted:
        print_body_line(text, injectionCpp, level)

#######################################################################################    
def write_case_entry(enum_name, enum_name_to_type_dict, injectionCpp, arg_type_dict):
    """
    Emit the switch case for one injection enum value.

    enum_name              - enumerator this case handles
    enum_name_to_type_dict - maps enumerator names to argument type strings
    injectionCpp           - handle passed through to print_body_line
    arg_type_dict          - maps argument type strings to a sequence whose
                             first two items are the m_value member name and
                             the is-pointer flag

    The case always handles 'other' having the same type. It additionally
    handles the pointer / non-pointer counterpart when that enumerator exists
    in enum_name_to_type_dict, and signed <-> unsigned int conversions for the
    four INJECTION_(U)INT(_PTR) enumerators. Nothing is emitted for the two
    const pointer enumerators.
    """
    # Don't support setting const pointers
    if enum_name in ('INJECTION_CONST_CHAR_PTR', 'INJECTION_CONST_NVMLGPUINSTANCEPLACEMENT_T_PTR'):
        return

    def emit(text, level):
        print_body_line(text, injectionCpp, level)

    emit('case %s:' % enum_name, 1)
    emit('{', 1)
    emit('if (other.m_type == %s)' % enum_name, 2)
    emit('{', 2)

    arg_type = enum_name_to_type_dict[enum_name]
    type_info = arg_type_dict[arg_type]
    member = type_info[0]
    is_ptr = type_info[1]
    is_nvml_type = arg_type[:4] == 'nvml'

    # --- other has exactly the same type ---
    if is_nvml_type:
        # Non-pointer members are copied by address, pointer members directly
        by_reference = not is_ptr
        print_memcpy(injectionCpp, 3, member, member, by_reference, by_reference)
        emit('set = true;', 3)
    elif enum_name == 'INJECTION_CHAR_PTR':
        emit('return NVML_ERROR_INVALID_ARGUMENT;', 3)
    else:
        print_equals_and_set(injectionCpp, 3, member, member, is_ptr, is_ptr)

    emit('}', 2)  # close generated if statement

    # --- other is the pointer / non-pointer counterpart of this type ---
    if is_ptr:
        # Strip the '_PTR' / 'Ptr' markers to name the value version
        counterpart_enum = enum_name[:-4]
        counterpart_member = member[:-3]
    else:
        counterpart_enum = '%s_PTR' % enum_name
        counterpart_member = '%sPtr' % member

    if counterpart_enum in enum_name_to_type_dict:
        emit('else if (other.m_type == %s)' % counterpart_enum, 2)
        emit('{', 2)
        if is_nvml_type:
            print_memcpy(injectionCpp, 3, member, counterpart_member, not is_ptr, bool(is_ptr))
            emit('set = true;', 3)
        else:
            print_equals_and_set(injectionCpp, 3, member, counterpart_member, bool(is_ptr), not is_ptr)
        emit('}', 2)  # close generated if statement

    # --- signed <-> unsigned int conversions ---
    # (source enumerator, source m_value member, guard applied to the source)
    # NOTE(review): the '> 0' guard also rejects a signed value of 0 - confirm
    # whether '>= 0' was intended.
    from_signed = ('INJECTION_INT', 'i', '> 0')
    from_unsigned = ('INJECTION_UINT', 'ui', '<= INT_MAX')
    conversions = {
        'INJECTION_UINT': ('this->m_value.ui', from_signed),
        'INJECTION_UINT_PTR': ('*this->m_value.uiPtr', from_signed),
        'INJECTION_INT': ('this->m_value.i', from_unsigned),
        'INJECTION_INT_PTR': ('*this->m_value.iPtr', from_unsigned),
    }

    if enum_name in conversions:
        target, (source_enum, source_member, guard) = conversions[enum_name]
        sources = (
            (source_enum, 'other.m_value.%s' % source_member),
            ('%s_PTR' % source_enum, '*other.m_value.%sPtr' % source_member),
        )
        for other_enum, other_value in sources:
            emit('else if (other.m_type == %s && %s %s)' % (other_enum, other_value, guard), 2)
            emit('{', 2)
            emit('%s = %s;' % (target, other_value), 3)
            emit('set = true;', 3)
            emit('}', 2)

    emit('break;', 2)
    emit('}', 1)  # close case block

#######################################################################################    
def write_injection_argument_cpp(enum_name_to_type_dict, output_dir, arg_type_dict):
    """
    Write the generated C++ source that defines InjectionArgument::SetValueFrom.

    enum_name_to_type_dict - maps enumerator names to argument type strings;
                             one switch case is attempted per key
    output_dir             - directory prefix for INJECTION_CPP_PATH
    arg_type_dict          - passed through to write_case_entry
    """
    # (indent level, emitted text) for the code before the per-type cases
    prologue = (
        (0, 'bool set = false;\n'),
        (0, 'if (other.IsEmpty())'),
        (0, '{'),
        (1, 'return NVML_ERROR_NOT_FOUND;'),
        (0, '}'),
        (0, 'switch (this->m_type)'),
        (0, '{'),
    )
    # (indent level, emitted text) for the code after the string case
    epilogue = (
        (1, 'default:'),
        (2, 'break;'),
        (0, '}'),
        (0, 'if (set)'),
        (0, '{'),
        (1, 'return NVML_SUCCESS;'),
        (0, '}'),
        (0, 'else'),
        (0, '{'),
        (1, 'return NVML_ERROR_INVALID_ARGUMENT;'),
        (0, '}'),
    )

    with open('%s/%s' % (output_dir, INJECTION_CPP_PATH), 'w') as cpp_out:
        cpp_out.write(AUTO_GENERATED_NOTICE)
        cpp_out.write('#include <%s>\n' % INJECTION_ARGUMENT_HEADER)
        cpp_out.write('#include <limits.h>\n')
        cpp_out.write('#include <cstring>\n\n\n')

        cpp_out.write('nvmlReturn_t InjectionArgument::SetValueFrom(const InjectionArgument &other)\n{\n')
        for level, text in prologue:
            print_body_line(text, cpp_out, level)

        for enum_name in enum_name_to_type_dict:
            write_case_entry(enum_name, enum_name_to_type_dict, cpp_out, arg_type_dict)

        write_string_case_entry(cpp_out)

        for level, text in epilogue:
            print_body_line(text, cpp_out, level)
        cpp_out.write('}\n')

#######################################################################################
def write_injection_structs_header(all_argument_types, output_dir):
    """
    Write the generated C header declaring simpleValue_t, injectionArgType_t
    and injectNvmlVal_t.

    all_argument_types - argument type strings; iterated in the given order
    output_dir         - directory prefix for INJECTION_STRUCTS_PATH

    Returns a tuple of three dicts:
      enum_dict              - argument type -> enumerator name
      enum_name_to_type_dict - enumerator name -> argument type
      arg_type_dict          - filled in by transform_arg_type
    """
    enum_dict = {}
    enum_name_to_type_dict = {}
    arg_type_dict = {}
    header_path = "%s/%s" % (output_dir, INJECTION_STRUCTS_PATH)

    with open(header_path, 'w') as header:
        header.write(AUTO_GENERATED_NOTICE)
        header.write('#pragma once\n\n')
        header.write('#include <nvml.h>\n')

        # Union with one member per argument type
        header.write('typedef union\n{\n')
        for arg_type in all_argument_types:
            member_name, is_ptr = transform_arg_type(arg_type, arg_type_dict)
            # Pointer type strings are joined to the member name without a space
            joiner = '' if is_ptr else ' '
            print_body_line('%s%s%s;' % (arg_type, joiner, member_name), header, 0)
        header.write('} simpleValue_t;\n\n')

        # Enum with one enumerator per argument type, values counting from 0
        header.write('typedef enum injectionArg_enum\n{\n')
        name_width = 0
        for arg_type in all_argument_types:
            enum_name = get_enum_name(arg_type)
            enum_dict[arg_type] = enum_name
            enum_name_to_type_dict[enum_name] = arg_type
            name_width = max(name_width, len(enum_name))

        # INJECTION_STRING takes the value right after the last argument type
        enumerators = [enum_dict[arg_type] for arg_type in all_argument_types]
        enumerators.append("INJECTION_STRING")
        for value, enum_name in enumerate(enumerators):
            print_body_line("%s = %d," % (enum_name.ljust(name_width), value), header, 0)
        print_body_line(INJECTION_ARG_COUNT_STR, header, 0)
        header.write('} injectionArgType_t;\n\n')

        # Plain struct pairing a value with its type tag
        header.write('typedef struct\n{\n')
        print_body_line('simpleValue_t value;', header, 0)
        print_body_line('injectionArgType_t type;', header, 0)
        header.write('} injectNvmlVal_t;\n\n')

    return enum_dict, enum_name_to_type_dict, arg_type_dict

#######################################################################################
def write_injection_argument_header(all_argument_types, output_dir):
    """
    Write the generated C++ header declaring the InjectionArgument class, then
    write the structs header and the InjectionArgument source alongside it.

    all_argument_types - list of argument type strings; SORTED IN PLACE here
    output_dir         - directory prefix for the generated files

    Returns enum_dict (argument type -> enumerator name) as produced by
    write_injection_structs_header.
    """
    all_argument_types.sort()
    injectionHeaderPath = '%s/%s' % (output_dir, INJECTION_ARGUMENT_PATH)

    # The structs header is written first; it also yields the lookup tables used
    # below. arg_type_dict values are [member name, isPtr, As*-suffix].
    enum_dict, enum_name_to_type_dict, arg_type_dict = write_injection_structs_header(all_argument_types, output_dir)

    with open(injectionHeaderPath, 'w') as injectionHeader:

        # File preamble and includes
        injectionHeader.write(AUTO_GENERATED_NOTICE)
        injectionHeader.write('#pragma once\n\n')
        injectionHeader.write('#include <cstring>\n')
        injectionHeader.write('#include <nvml.h>\n')
        injectionHeader.write('#include <string>\n\n')
        injectionHeader.write('#include "%s"\n\n' % INJECTION_STRUCTS_NAME)


        # Class opening, private members, default and injectNvmlVal_t constructors
        injectionHeader.write('class InjectionArgument\n{\nprivate:\n')
        print_body_line('injectionArgType_t m_type;', injectionHeader, 0)
        print_body_line('simpleValue_t m_value;', injectionHeader, 0)
        print_body_line('std::string m_str;\n', injectionHeader, 0)
        injectionHeader.write('public:\n')
        print_body_line('InjectionArgument()', injectionHeader, 0)
        print_body_line(': m_type(%s)' % INJECTION_ARG_COUNT_STR, injectionHeader, 1)
        print_body_line('{', injectionHeader, 0)
        print_body_line('Clear();', injectionHeader, 1)
        print_body_line('}\n', injectionHeader, 0)
        print_body_line('InjectionArgument(const injectNvmlVal_t &value)', injectionHeader, 0)
        print_body_line(': m_type(value.type)', injectionHeader, 1)
        print_body_line(', m_value(value.value)', injectionHeader, 1)
        print_body_line('{}\n', injectionHeader, 0)
        # SetValueFrom declaration (defined in the generated .cpp) and simple accessors
        print_body_line('/**', injectionHeader, 0)
        print_body_line(' * SetValueFrom - Sets this injection argument based other\'s value', injectionHeader, 0)
        print_body_line(' * @param other - the InjectionArgument whose value we flexibly copy if possible.', injectionHeader, 0)
        print_body_line(' *', injectionHeader, 0)
        print_body_line(' * @return 0 if we could set from other\'s value, 1 if incompatible', injectionHeader, 0)
        print_body_line(' **/', injectionHeader, 0)
        print_body_line('nvmlReturn_t SetValueFrom(const InjectionArgument &other);\n', injectionHeader, 0)
        print_body_line('injectionArgType_t GetType() const', injectionHeader, 0)
        print_body_line('{', injectionHeader, 0)
        print_body_line('return m_type;', injectionHeader, 1)
        print_body_line('}\n', injectionHeader, 0)
        print_body_line('simpleValue_t GetSimpleValue() const', injectionHeader, 0)
        print_body_line('{', injectionHeader, 0)
        print_body_line('return m_value;', injectionHeader, 1)
        print_body_line('}\n', injectionHeader, 0)
        print_body_line('void Clear()', injectionHeader, 0)
        print_body_line('{', injectionHeader, 0)
        print_body_line('memset(&m_value, 0, sizeof(m_value));', injectionHeader, 1)
        print_body_line('}', injectionHeader, 0)
        # Compare(): orders by type first, then by string value or by m_value member
        print_body_line('int Compare(const InjectionArgument &other) const', injectionHeader, 0)
        print_body_line('{', injectionHeader, 0)
        print_body_line('if (m_type < other.m_type)', injectionHeader, 1)
        print_body_line('{', injectionHeader, 1)
        print_body_line('return -1;', injectionHeader, 2)
        print_body_line('}', injectionHeader, 1)
        print_body_line('else if (m_type > other.m_type)', injectionHeader, 1)
        print_body_line('{', injectionHeader, 1)
        print_body_line('return 1;', injectionHeader, 2)
        print_body_line('}', injectionHeader, 1)
        print_body_line('else', injectionHeader, 1)
        print_body_line('{', injectionHeader, 1)
        print_body_line('if (m_type == INJECTION_STRING)', injectionHeader, 2)
        print_body_line('{', injectionHeader, 2)
        print_body_line('if (m_str < other.m_str)', injectionHeader, 3)
        print_body_line('{', injectionHeader, 3)
        print_body_line('return -1;', injectionHeader, 4)
        print_body_line('}', injectionHeader, 3)
        print_body_line('else if (m_str > other.m_str)', injectionHeader, 3)
        print_body_line('{', injectionHeader, 3)
        print_body_line('return 1;', injectionHeader, 4)
        print_body_line('}', injectionHeader, 3)
        print_body_line('else', injectionHeader, 3)
        print_body_line('{', injectionHeader, 3)
        print_body_line('return 0;', injectionHeader, 4)
        print_body_line('}', injectionHeader, 3)
        print_body_line('}', injectionHeader, 2)
        print_body_line('else', injectionHeader, 2)
        print_body_line('{', injectionHeader, 2)
        print_body_line('switch (m_type)', injectionHeader, 3)
        print_body_line('{', injectionHeader, 3)
        # One comparison case per argument type
        for arg_type in all_argument_types:
            typeTuple = arg_type_dict[arg_type]
            enumName = enum_dict[arg_type]
            print_body_line('case %s:' % enumName, injectionHeader, 4)
            print_body_line('{', injectionHeader, 4)
            if typeTuple[1]:
                # Pointer members: strcmp for the char pointers, memcmp of the pointee otherwise.
                # The member-name length picks a one-, two- or three-line layout for the memcmp
                # (presumably to keep the generated lines within a width limit - confirm).
                if enumName == 'INJECTION_CHAR_PTR' or enumName == 'INJECTION_CONST_CHAR_PTR':
                    print_body_line('return strcmp(m_value.%s, other.m_value.%s);' % (typeTuple[0], typeTuple[0]), injectionHeader, 5)
                elif len(typeTuple[0]) <= 12:
                    print_body_line('return memcmp(m_value.%s, other.m_value.%s, sizeof(*m_value.%s));' % (typeTuple[0], typeTuple[0], typeTuple[0]), injectionHeader, 5)
                elif len(typeTuple[0]) <= 15:
                    print_body_line('return memcmp(', injectionHeader, 5)
                    print_body_line('m_value.%s, other.m_value.%s, sizeof(*m_value.%s));' % (typeTuple[0], typeTuple[0], typeTuple[0]), injectionHeader, 6)
                else:
                    print_body_line('return memcmp(m_value.%s,' % typeTuple[0], injectionHeader, 5)
                    print_body_line('  other.m_value.%s,' % typeTuple[0], injectionHeader, 8)
                    print_body_line('  sizeof(*m_value.%s));' % typeTuple[0], injectionHeader, 8)
            else:
                # Non-pointer members are compared with < and >
                print_body_line('if (m_value.%s < other.m_value.%s)'  % (typeTuple[0], typeTuple[0]), injectionHeader, 5)
                print_body_line('{', injectionHeader, 5)
                print_body_line('return -1;', injectionHeader, 6)
                print_body_line('}', injectionHeader, 5)
                print_body_line('else if (m_value.%s > other.m_value.%s)' % (typeTuple[0], typeTuple[0]), injectionHeader, 5)
                print_body_line('{', injectionHeader, 5)
                print_body_line('return 1;', injectionHeader, 6)
                print_body_line('}', injectionHeader, 5)
                print_body_line('else', injectionHeader, 5)
                print_body_line('{', injectionHeader, 5)
                print_body_line('return 0;', injectionHeader, 6)
                print_body_line('}', injectionHeader, 5)
            print_body_line('break; // NOT REACHED', injectionHeader, 5)
            print_body_line('}', injectionHeader, 4)
        print_body_line('default:', injectionHeader, 4)
        print_body_line('break;', injectionHeader, 5)
        print_body_line('}', injectionHeader, 3)
        print_body_line('}', injectionHeader, 2)
        print_body_line('}', injectionHeader, 1)
        # NOTE(review): this emits 'return true;' as the fall-through result of a function
        # declared 'int Compare(...)', i.e. 1 for equal types that hit the switch default.
        # Confirm whether 0 was intended.
        print_body_line('return true;', injectionHeader, 1)
        print_body_line('}', injectionHeader, 0)
        # Operators built on Compare(), plus IsEmpty()
        print_body_line('bool operator<(const InjectionArgument &other) const', injectionHeader, 0)
        print_body_line('{', injectionHeader, 0)
        print_body_line('return this->Compare(other) == -1;', injectionHeader, 1)
        print_body_line('}\n', injectionHeader, 0)
        print_body_line('bool operator==(const InjectionArgument &other) const', injectionHeader, 0)
        print_body_line('{', injectionHeader, 0)
        print_body_line('return this->Compare(other) == 0;', injectionHeader, 1)
        print_body_line('}\n', injectionHeader, 0)
        print_body_line('bool IsEmpty() const', injectionHeader, 0)
        print_body_line('{', injectionHeader, 0)
        print_body_line('return m_type == %s;' % INJECTION_ARG_COUNT_STR, injectionHeader, 1)
        print_body_line('}\n', injectionHeader, 0)
        # One constructor and one As<Suffix>() accessor per argument type
        for arg_type in all_argument_types:
            # Write constructor
            typeTuple = arg_type_dict[arg_type]
            # Pointer type strings are joined to the parameter name without a space
            if is_pointer_type(arg_type):
                print_body_line('InjectionArgument(%s%s)' % (arg_type, typeTuple[0]), injectionHeader, 0)
            else:
                print_body_line('InjectionArgument(%s %s)' % (arg_type, typeTuple[0]), injectionHeader, 0)
            print_body_line(': m_type(%s)' % (enum_dict[arg_type]), injectionHeader, 1)
            print_body_line('{', injectionHeader, 0)
            print_body_line('memset(&m_value, 0, sizeof(m_value));', injectionHeader, 1)
            print_body_line('m_value.%s = %s;' % (typeTuple[0], typeTuple[0]), injectionHeader, 1)
            print_body_line('}', injectionHeader, 0)

            # Write As* function
            if typeTuple[1]:
                print_body_line('%sAs%s() const' % (arg_type, typeTuple[2]), injectionHeader, 0)
            else:
                print_body_line('%s As%s() const' % (arg_type, typeTuple[2]), injectionHeader, 0)
            print_body_line('{', injectionHeader, 0)
            print_body_line('return m_value.%s;' % typeTuple[0], injectionHeader, 1)
            print_body_line('}\n', injectionHeader, 0)

        # std::string constructor and AsString().
        # NOTE(review): AsString() hard-codes the m_value members 'str' and 'const_str';
        # this assumes transform_arg_type gives the char pointer types those names - confirm.
        print_body_line('InjectionArgument(const std::string &val)', injectionHeader, 0)
        print_body_line(': m_type(INJECTION_STRING)', injectionHeader, 1)
        print_body_line(', m_str(val)', injectionHeader, 1)
        print_body_line('{', injectionHeader, 0)
        print_body_line('memset(&m_value, 0, sizeof(m_value));', injectionHeader, 1)
        print_body_line('}', injectionHeader, 0)
        print_body_line('std::string AsString() const', injectionHeader, 0)
        print_body_line('{', injectionHeader, 0)
        print_body_line('switch (m_type)', injectionHeader, 1)
        print_body_line('{', injectionHeader, 1)
        print_body_line('case INJECTION_STRING:', injectionHeader, 2)
        print_body_line('{', injectionHeader, 2)
        print_body_line('return m_str;', injectionHeader, 3)
        print_body_line('}', injectionHeader, 2)
        print_body_line('break;', injectionHeader, 3)
        print_body_line('case INJECTION_CHAR_PTR:', injectionHeader, 2)
        print_body_line('{', injectionHeader, 2)
        print_body_line('if (m_value.str != nullptr)', injectionHeader, 3)
        print_body_line('{', injectionHeader, 3)
        print_body_line('return std::string(m_value.str);', injectionHeader, 4)
        print_body_line('}', injectionHeader, 3)
        print_body_line('break;', injectionHeader, 3)
        print_body_line('}', injectionHeader, 2)
        print_body_line('case INJECTION_CONST_CHAR_PTR:', injectionHeader, 2)
        print_body_line('{', injectionHeader, 2)
        print_body_line('if (m_value.const_str != nullptr)', injectionHeader, 3)
        print_body_line('{', injectionHeader, 3)
        print_body_line('return std::string(m_value.const_str);', injectionHeader, 4)
        print_body_line('}', injectionHeader, 3)
        print_body_line('break;', injectionHeader, 3)
        print_body_line('}', injectionHeader, 2)
        print_body_line('default:', injectionHeader, 2)
        print_body_line('break;', injectionHeader, 3) 
        print_body_line('}', injectionHeader, 1)
        print_body_line('return "";', injectionHeader, 1)
        print_body_line('}', injectionHeader, 0)

        injectionHeader.write('};\n')

    # The matching .cpp (SetValueFrom definition) is generated from the same tables
    write_injection_argument_cpp(enum_name_to_type_dict, output_dir, arg_type_dict)

    return enum_dict

#######################################################################################    
def get_enum_from_arg_type(enum_dict, arg_type):
    """
    Look up the enumerator name for arg_type in enum_dict.

    The key used is whatever get_true_arg_type returns for arg_type; a
    KeyError is raised when enum_dict has no entry for it.
    """
    return enum_dict[get_true_arg_type(arg_type)]

#######################################################################################    
def write_key_file(output_dir):
    """
    Write the injection key definitions and their extern declarations.

    For every key in g_key_to_function, the source file (KEY_LIST_PATH) gets a
    'const char *INJECTION_<KEY>_KEY' definition annotated with the associated
    function name(s), and the header (KEY_LIST_HEADER_PATH) gets the matching
    extern declaration. Both paths are relative to output_dir.
    """
    source_path = "%s/%s" % (output_dir, KEY_LIST_PATH)
    with open(source_path, 'w') as source_out:
        source_out.write(AUTO_GENERATED_NOTICE)
        source_out.write("// clang-format off\n")
        for key, function_names in g_key_to_function.items():
            definition = 'const char *INJECTION_%s_KEY = "%s"; // Function name(s): %s\n' % (
                key.upper(), key, function_names)
            source_out.write(definition)

    header_path = "%s/%s" % (output_dir, KEY_LIST_HEADER_PATH)
    with open(header_path, 'w') as header_out:
        header_out.write(AUTO_GENERATED_NOTICE)
        for key in g_key_to_function:
            header_out.write("extern const char *INJECTION_%s_KEY;\n" % key.upper())


#######################################################################################
def write_linux_defs(output_dir, func_dict):
    """
    Write the symbol list file at output_dir/LINUX_DEFS_PATH.

    output_dir - directory prefix for LINUX_DEFS_PATH
    func_dict  - iterable of generated function names to list under 'global:'

    The 'global:' section lists a fixed set of manually written functions, then
    every name in func_dict, then the _ZTI*/_ZTS* patterns; everything else is
    placed under 'local:'.
    """
    linux_defs_path = '%s/%s' % (output_dir, LINUX_DEFS_PATH)

    # Entry points that are not produced by this generator
    hand_written = (
        'injectionNvmlInit',
        'nvmlDeviceSimpleInject',
        'nvmlDeviceInjectExtraKey',
        'nvmlDeviceInjectFieldValue',
    )

    # (indent level, emitted text) following the exported function names
    trailer = (
        (1, 'extern "C++" {'),
        (2, '_ZTI*;'),
        (2, '_ZTS*;'),
        (1, '};\n'),
        (0, 'local:'),
        (1, '*;'),
        (1, 'extern "C++" {'),
        (2, '*;'),
        (1, '};'),
    )

    with open(linux_defs_path, 'w') as defs_out:
        defs_out.write('{\n    global:\n')
        for name_source in (hand_written, func_dict):
            for funcname in name_source:
                print_body_line('%s;' % funcname, defs_out, 1)

        for level, text in trailer:
            print_body_line(text, defs_out, level)
        defs_out.write('};')

####################################################################################### 
def parse_entry_points_contents(contents, output_dir):
    """
    Generate every output file from the text of the NVML entry points header.

    contents   - full text of the entry points file; split on the literal
                 'NVML_ENTRY_POINT' to obtain one chunk per entry point
    output_dir - directory prefix for all generated files

    Writes the stub file (STUB_PATH), then the key files, the linux defs file,
    the injection argument header/source and the declarations file. A summary
    is printed to stdout and the names of functions that write_function()
    declined to generate are written to 'ungenerated.txt' in the current
    working directory (not output_dir).
    """
    function_dict = {}
    all_argument_types = []
    all_functions = AllFunctions()
    all_functypes = AllFunctionTypes()

    entry_points = contents.split('NVML_ENTRY_POINT')
    total_funcs = 0
    auto_generated = 0
    not_generated = []

    outputStubPath = '%s/%s' % (output_dir, STUB_PATH)

    with open(outputStubPath, 'w') as stub_file:
        write_stub_file_header(stub_file)

        # Pass 1: parse each chunk; the first chunk is flagged for
        # get_function_signature. Chunks without a name or argument list are skipped.
        for position, entry_point in enumerate(entry_points):
            funcname, arg_list, arg_names = get_function_signature(entry_point, position == 0)
            if funcname and arg_list:
                all_functions.AddFunction(FunctionInfo(funcname, arg_list, arg_names))

        #all_functions.RemoveEarlierVersions()

        # Pass 2: register every function type before any stub is written,
        # because write_function() receives the complete all_functypes.
        for funcname in all_functions.func_dict:
            all_functypes.AddFunctionType(funcname, all_functions.func_dict[funcname])

        # Pass 3: collect argument types and write the stubs
        for funcname in all_functions.func_dict:
            funcinfo = all_functions.func_dict[funcname]
            argument_list = funcinfo.GetArgumentList()
            build_argument_type_list(argument_list, all_argument_types)
            if write_function(stub_file, funcinfo, all_functypes):
                auto_generated += 1
            else:
                not_generated.append(funcname)
            total_funcs += 1
            # Record this function's own argument list. Previously the variable
            # left over from pass 1 was stored, so every function received the
            # argument list of the last parsed entry point.
            # NOTE(review): presumably GetArgumentList() returns the list the
            # FunctionInfo was built with - verify before re-enabling
            # write_function_info below.
            function_dict[funcname] = argument_list

        stub_file.write("#ifdef __cplusplus\n}\n")
        stub_file.write("#endif\n")
        stub_file.write('// END nvml_generated_stubs')

    write_key_file(output_dir)
    write_linux_defs(output_dir, all_functions.func_dict)

    enum_dict = write_injection_argument_header(all_argument_types, output_dir)
    write_declarations_file(all_functypes.GetAllFunctionDeclarations(), output_dir)
    #write_function_info(enum_dict, function_dict, output_dir)

    print("I was able to generate the injection body for %d of %d functions" % (auto_generated, total_funcs))
    with open('ungenerated.txt', 'w') as ungenerated:
        ungenerated.write('The following were not auto-generated:\n\n')
        for ungen in not_generated:
            ungenerated.write("%s\n" % ungen)

#######################################################################################    
def parse_entry_points(inputPath, output_dir):
    """
    Read the entry points file at inputPath and generate all outputs from it.

    inputPath  - path of the file to read (opened in text mode)
    output_dir - passed through to parse_entry_points_contents
    """
    with open(inputPath, 'r') as entry_points_file:
        parse_entry_points_contents(entry_points_file.read(), output_dir)

#######################################################################################    
def main():
    """
    Parse the command line and run the generator.

    Options:
      -i / --input-file  entry points file (default 'sdk/nvml/entry_points.h')
      -o / --output-dir  directory for generated files (default '.')
    """
    # (short flag, long flag, default, attribute name on the parsed namespace)
    options = (
        ('-i', '--input-file', 'sdk/nvml/entry_points.h', 'inputPath'),
        ('-o', '--output-dir', '.', 'outputDir'),
    )

    parser = argparse.ArgumentParser()
    for short_flag, long_flag, fallback, attribute in options:
        parser.add_argument(short_flag, long_flag, default=fallback, dest=attribute)

    parsed = parser.parse_args()
    parse_entry_points(parsed.inputPath, parsed.outputDir)

# Script entry point: run the generator when executed directly.
# NOTE(review): this guard runs before write_function_info (defined further down
# in this file) has been defined, so main() cannot call that function as the
# file is currently laid out.
if __name__ == '__main__':
    main()


# TODO: delete this once we're sure we aren't using it
#######################################################################################    
def write_function_info(enum_dict, function_dict, output_dir):
    """
    Write the functionInfos[] table to output_dir/FUNCTION_INFO_PATH.

    enum_dict     - argument type -> enumerator name, used via
                    get_enum_from_arg_type
    function_dict - function name -> argument list, as accepted by
                    get_argument_types_from_argument_list
    output_dir    - directory prefix for FUNCTION_INFO_PATH

    Each table row holds the function name, its argument count, and the
    enumerators of its argument types followed by INJECTION_ARG_COUNT_STR
    repeated up to MAX_NVML_ARGS entries.
    """
    function_info_path = '%s/%s' % (output_dir, FUNCTION_INFO_PATH)
    with open(function_info_path, 'w') as funcInfoFile:
        funcInfoFile.write(AUTO_GENERATED_NOTICE)
        funcInfoFile.write('#include <nvml.h>\n')
        funcInfoFile.write('#include <%s>\n\n' % INJECTION_ARGUMENT_HEADER)

        funcInfoFile.write('functionInfo_t functionInfos[] = {\n')
        for functionName in function_dict:
            argument_types = get_argument_types_from_argument_list(function_dict[functionName])
            enum_names = ['%s' % get_enum_from_arg_type(enum_dict, arg_type) for arg_type in argument_types]
            argCount = len(enum_names)
            padding = [INJECTION_ARG_COUNT_STR] * max(MAX_NVML_ARGS - argCount, 0)

            if enum_names:
                # Real types, a four-space gap, then the padding entries
                type_list_str = ', '.join(enum_names) + '    ' + ''.join(', %s' % pad for pad in padding)
            else:
                # With no arguments there is nothing for the first padding entry
                # to follow, so emit the padding alone; a leading ', ' would be
                # a syntax error in the generated initializer.
                type_list_str = ', '.join(padding)

            print_body_line('{ \"%s\", %d, { %s } },\n' % (functionName, argCount, type_list_str), funcInfoFile, 0)

        funcInfoFile.write('};\n\n')

